{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to add persistence (\"memory\") to your graph\n",
    "\n",
    "Many AI applications need memory to share context across multiple interactions. In LangGraph4j, memory is provided for any [`StateGraph`] through [`Checkpointers`].\n",
    "\n",
    "When creating any LangGraph workflow, you can set them up to persist their state by doing using the following:\n",
    "\n",
    "1. A [`Checkpointer`], such as the [`MemorySaver`]\n",
    "1. Pass your [`Checkpointers`] in configuration when compiling the graph.\n",
    "\n",
    "\n",
    "[`StateGraph`]: https://bsorrentino.github.io/langgraph4j/apidocs/org/bsc/langgraph4j/StateGraph.html\n",
    "[`Checkpointers`]: https://bsorrentino.github.io/langgraph4j/apidocs/org/bsc/langgraph4j/checkpoint/BaseCheckpointSaver.html\n",
    "[`Checkpointer`]: https://bsorrentino.github.io/langgraph4j/apidocs/org/bsc/langgraph4j/checkpoint/Checkpoint.html\n",
    "[`MemorySaver`]: https://bsorrentino.github.io/langgraph4j/apidocs/org/bsc/langgraph4j/checkpoint/MemorySaver.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "var userHomeDir = System.getProperty(\"user.home\");\n",
    "var localRespoUrl = \"file://\" + userHomeDir + \"/.m2/repository/\";\n",
    "var langchain4jVersion = \"1.0.1\";\n",
    "var langchain4jbeta = \"1.0.1-beta6\";\n",
    "var langgraph4jVersion = \"1.6-SNAPSHOT\";"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%dependency /add-repo local \\{localRespoUrl} release|never snapshot|always\n",
    "// %dependency /list-repos\n",
    "%dependency /add org.bsc.langgraph4j:langgraph4j-core:\\{langgraph4jVersion}\n",
    "%dependency /add org.bsc.langgraph4j:langgraph4j-langchain4j:\\{langgraph4jVersion}\n",
    "%dependency /add dev.langchain4j:langchain4j:\\{langchain4jVersion}\n",
    "%dependency /add dev.langchain4j:langchain4j-open-ai:\\{langchain4jVersion}\n",
    "\n",
    "%dependency /resolve"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Initialize Logger**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "try( var file = new java.io.FileInputStream(\"./logging.properties\")) {\n",
    "    java.util.logging.LogManager.getLogManager().readConfiguration( file );\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the state\n",
    "\n",
    "State is an (immutable) data class, inheriting from prebuilt [MessagesState], shared with all nodes in our graph. A state is basically a wrapper of a `Map<String,Object>` that provides some enhancers:\n",
    "\n",
    "1. Schema (optional), that is a `Map<String,Channel>` where each [`Channel`] describe behaviour of the related property\n",
    "1. `value()` accessors that inspect Map an return an Optional of value contained and cast to the required type\n",
    "\n",
    "[`Channel`]: https://bsorrentino.github.io/langgraph4j/apidocs/org/bsc/langgraph4j/state/Channel.html\n",
    "[MessagesState]: https://bsorrentino.github.io/langgraph4j/apidocs/org/bsc/langgraph4j/prebuilt/MessagesState.html"
   ]
  },
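   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "The idea of a state as a map wrapper can be illustrated with a minimal, plain-Java sketch. It is not LangGraph4j's implementation: `SketchState`, its `merge` method and the `appendKeys` parameter are hypothetical names introduced only for illustration, and the appending behaviour is an assumption about what a channel might do. The sketch shows a typed `value()` accessor that returns an `Optional`, and an update step in which selected keys accumulate values while all other keys are overwritten.\n",
     "\n",
     "```java\n",
     "import java.util.ArrayList;\n",
     "import java.util.HashMap;\n",
     "import java.util.List;\n",
     "import java.util.Map;\n",
     "import java.util.Optional;\n",
     "import java.util.Set;\n",
     "\n",
     "// Hypothetical stand-in for a state class: a read-only wrapper around a map.\n",
     "final class SketchState {\n",
     "    private final Map<String, Object> data;\n",
     "\n",
     "    SketchState(Map<String, Object> initData) {\n",
     "        this.data = Map.copyOf(initData); // defensive, unmodifiable copy\n",
     "    }\n",
     "\n",
     "    // Typed accessor: returns an empty Optional when the key is missing.\n",
     "    @SuppressWarnings(\"unchecked\")\n",
     "    <T> Optional<T> value(String key) {\n",
     "        return Optional.ofNullable((T) data.get(key));\n",
     "    }\n",
     "\n",
     "    // Returns a new state. Keys listed in appendKeys accumulate their values\n",
     "    // in a list; every other key is simply overwritten.\n",
     "    SketchState merge(Map<String, Object> update, Set<String> appendKeys) {\n",
     "        Map<String, Object> next = new HashMap<>(data);\n",
     "        update.forEach((key, value) -> {\n",
     "            if (appendKeys.contains(key)) {\n",
     "                List<Object> items = new ArrayList<>();\n",
     "                Object previous = next.get(key);\n",
     "                if (previous instanceof List) {\n",
     "                    items.addAll((List<?>) previous);\n",
     "                } else if (previous != null) {\n",
     "                    items.add(previous);\n",
     "                }\n",
     "                items.add(value);\n",
     "                next.put(key, List.copyOf(items));\n",
     "            } else {\n",
     "                next.put(key, value);\n",
     "            }\n",
     "        });\n",
     "        return new SketchState(next);\n",
     "    }\n",
     "}\n",
     "```\n",
     "\n",
     "Each `merge` call returns a new instance, so the map wrapped by an existing state never changes."
    ]
   },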
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.bsc.langgraph4j.prebuilt.MessagesState;\n",
    "import org.bsc.langgraph4j.state.Channel;\n",
    "import dev.langchain4j.data.message.AiMessage;\n",
    "import dev.langchain4j.data.message.ChatMessage;\n",
    "\n",
    "public class MessageState extends MessagesState<ChatMessage> {\n",
    "\n",
    "    public MessageState(Map<String, Object> initData) {\n",
    "        super( initData  );\n",
    "    }\n",
    "\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Register Serializers\n",
    "\n",
    "Every object that should be stored into State **MUST BE SERIALIZABLE**. If the object is not `Serializable` by default, Langgraph4j provides a way to build and associate a custom [Serializer] to it. \n",
    "\n",
    "In the example, we use  [`Serializer`] for [ToolExecutionRequest] and [ChatMesssager] provided by langgraph4j integration module with langchain4j .\n",
    "\n",
    "[Serializer]: https://bsorrentino.github.io/langgraph4j/apidocs/org/bsc/langgraph4j/serializer/Serializer.html\n",
    "[ChatMesssager]: https://docs.langchain4j.dev/apidocs/dev/langchain4j/data/message/ChatMesssager.html\n",
    "[ToolExecutionRequest]: https://docs.langchain4j.dev/apidocs/dev/langchain4j/data/message/ToolExecutionRequest.html"
   ]
  },
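   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "To give a feel for what a custom serializer does, here is a minimal, self-contained sketch in plain Java. It does not use LangGraph4j: the `SketchSerializer` interface is a hypothetical stand-in that assumes a write/read pair over `ObjectOutput` and `ObjectInput`, and `ToolCall`, `ToolCallSerializer` and `RoundTrip` are illustrative names only. The sketch writes the fields of a class that does not implement `Serializable` and reads them back in the same order.\n",
     "\n",
     "```java\n",
     "import java.io.ByteArrayInputStream;\n",
     "import java.io.ByteArrayOutputStream;\n",
     "import java.io.IOException;\n",
     "import java.io.ObjectInput;\n",
     "import java.io.ObjectInputStream;\n",
     "import java.io.ObjectOutput;\n",
     "import java.io.ObjectOutputStream;\n",
     "\n",
     "// Hypothetical interface: one method to write a value, one to read it back.\n",
     "interface SketchSerializer<T> {\n",
     "    void write(T object, ObjectOutput out) throws IOException;\n",
     "    T read(ObjectInput in) throws IOException, ClassNotFoundException;\n",
     "}\n",
     "\n",
     "// A value type that does NOT implement java.io.Serializable.\n",
     "final class ToolCall {\n",
     "    final String name;\n",
     "    final String arguments;\n",
     "\n",
     "    ToolCall(String name, String arguments) {\n",
     "        this.name = name;\n",
     "        this.arguments = arguments;\n",
     "    }\n",
     "}\n",
     "\n",
     "// Writes the fields one by one and reads them back in the same order.\n",
     "final class ToolCallSerializer implements SketchSerializer<ToolCall> {\n",
     "    @Override\n",
     "    public void write(ToolCall object, ObjectOutput out) throws IOException {\n",
     "        out.writeUTF(object.name);\n",
     "        out.writeUTF(object.arguments);\n",
     "    }\n",
     "\n",
     "    @Override\n",
     "    public ToolCall read(ObjectInput in) throws IOException {\n",
     "        return new ToolCall(in.readUTF(), in.readUTF());\n",
     "    }\n",
     "}\n",
     "\n",
     "final class RoundTrip {\n",
     "    // Serializes a value to a byte array and rebuilds it from those bytes.\n",
     "    static <T> T copy(T value, SketchSerializer<T> serializer)\n",
     "            throws IOException, ClassNotFoundException {\n",
     "        ByteArrayOutputStream bytes = new ByteArrayOutputStream();\n",
     "        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {\n",
     "            serializer.write(value, out);\n",
     "        }\n",
     "        try (ObjectInputStream in =\n",
     "                 new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {\n",
     "            return serializer.read(in);\n",
     "        }\n",
     "    }\n",
     "}\n",
     "```\n",
     "\n",
     "The important design point is symmetry: `read` must consume exactly what `write` produced, in the same order, otherwise the restored object is corrupted."
    ]
   },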
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "SLF4J: No SLF4J providers were found.\n",
      "SLF4J: Defaulting to no-operation (NOP) logger implementation\n",
      "SLF4J: See https://www.slf4j.org/codes.html#noProviders for further details.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "SerializerMapper: \n",
       "java.util.Map\n",
       "java.util.Collection\n",
       "dev.langchain4j.agent.tool.ToolExecutionRequest\n",
       "dev.langchain4j.data.message.ChatMessage"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import dev.langchain4j.data.message.AiMessage;\n",
    "import dev.langchain4j.data.message.SystemMessage;\n",
    "import dev.langchain4j.data.message.UserMessage;\n",
    "import dev.langchain4j.data.message.ToolExecutionResultMessage;\n",
    "import dev.langchain4j.agent.tool.ToolExecutionRequest;\n",
    "import org.bsc.langgraph4j.serializer.std.ObjectStreamStateSerializer;\n",
    "import org.bsc.langgraph4j.langchain4j.serializer.std.ChatMesssageSerializer;\n",
    "import org.bsc.langgraph4j.langchain4j.serializer.std.ToolExecutionRequestSerializer;\n",
    "import org.bsc.langgraph4j.state.AgentStateFactory;\n",
    "\n",
    "var stateSerializer = new ObjectStreamStateSerializer<MessageState>( MessageState::new );\n",
    "stateSerializer.mapper()\n",
    "    // Setup custom serializer for Langchain4j ToolExecutionRequest\n",
    "    .register(ToolExecutionRequest.class, new ToolExecutionRequestSerializer() )\n",
    "    // Setup custom serializer for Langchain4j AiMessage\n",
    "    .register(ChatMessage.class, new ChatMesssageSerializer() )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up the tools\n",
    "\n",
    "Using [langchain4j], We will first define the tools we want to use. For this simple example, we will\n",
    "use create a placeholder search engine. However, it is really easy to create\n",
    "your own tools - see documentation\n",
    "[here][tools] on how to do\n",
    "that.\n",
    "\n",
    "[langchain4j]: https://docs.langchain4j.dev\n",
    "[tools]: https://docs.langchain4j.dev/tutorials/tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dev.langchain4j.agent.tool.P;\n",
    "import dev.langchain4j.agent.tool.Tool;\n",
    "\n",
    "import java.util.Optional;\n",
    "\n",
    "import static java.lang.String.format;\n",
    "\n",
    "public class SearchTool {\n",
    "\n",
    "    @Tool(\"Use to surf the web, fetch current information, check the weather, and retrieve other information.\")\n",
    "    String execQuery(@P(\"The query to use in your search.\") String query) {\n",
    "\n",
    "        // This is a placeholder for the actual implementation\n",
    "        return \"Cold, with a low of 13 degrees\";\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up the model\n",
    "\n",
    "Now we will load the\n",
    "[chat model].\n",
    "\n",
    "1. It should work with messages. We will represent all agent state in the form of messages, so it needs to be able to work well with them.\n",
    "2. It should work with [tool calling],meaning it can return function arguments in its response.\n",
    "\n",
    "Note:\n",
    "   >\n",
    "   > These model requirements are not general requirements for using LangGraph4j - they are just requirements for this one example.\n",
    "   >\n",
    "\n",
    "[chat model]: https://docs.langchain4j.dev/tutorials/chat-and-language-models\n",
    "[tool calling]: https://docs.langchain4j.dev/tutorials/tools   \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dev.langchain4j.model.openai.OpenAiChatModel;\n",
    "import dev.langchain4j.agent.tool.ToolSpecification;\n",
    "import dev.langchain4j.agent.tool.ToolSpecifications;\n",
    "\n",
    "var model = OpenAiChatModel.builder()\n",
    "                .apiKey( System.getenv(\"OPENAI_API_KEY\") )\n",
    "                .modelName( \"gpt-4o\" )\n",
    "                .logResponses(true)\n",
    "                .maxRetries(2)\n",
    "                .temperature(0.0)\n",
    "                .maxTokens(2000)\n",
    "                .build()   \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test function calling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AiMessage { text = \"I'm unable to provide real-time weather forecasts. For the most accurate and up-to-date weather information, please check a reliable weather website or app.\" toolExecutionRequests = [] }"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import dev.langchain4j.agent.tool.ToolSpecification;\n",
    "import dev.langchain4j.agent.tool.ToolSpecifications;\n",
    "import dev.langchain4j.data.message.UserMessage;\n",
    "import dev.langchain4j.data.message.AiMessage;\n",
    "import dev.langchain4j.model.output.Response;\n",
    "import dev.langchain4j.model.openai.OpenAiChatModel;\n",
    "import dev.langchain4j.model.chat.request.ChatRequest;\n",
    "\n",
    "var tools = ToolSpecifications.toolSpecificationsFrom( SearchTool.class );\n",
    "\n",
    "var userMessage = UserMessage.from(\"What will the weather be like in London tomorrow?\");\n",
    "\n",
    "var request = ChatRequest.builder()\n",
    "                .messages( userMessage )\n",
    "                .build();\n",
    "\n",
    "var result = model.chat(request );\n",
    "\n",
    "result.aiMessage();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the graph\n",
    "\n",
    "We can now put it all together. We will run it first without a checkpointer:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import static org.bsc.langgraph4j.StateGraph.START;\n",
    "import static org.bsc.langgraph4j.StateGraph.END;\n",
    "import org.bsc.langgraph4j.StateGraph;\n",
    "import org.bsc.langgraph4j.action.EdgeAction;\n",
    "import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;\n",
    "import org.bsc.langgraph4j.action.NodeAction;\n",
    "import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;\n",
    "import dev.langchain4j.data.message.AiMessage;\n",
    "import dev.langchain4j.data.message.ChatMessage;\n",
    "import dev.langchain4j.service.tool.DefaultToolExecutor;\n",
    "\n",
    "// Route Message \n",
    "EdgeAction<MessageState> routeMessage = state -> {\n",
    "  \n",
    "  var lastMessage = state.lastMessage();\n",
    "  \n",
    "  if ( !lastMessage.isPresent()) return \"exit\";\n",
    "\n",
    "  var message = (AiMessage)lastMessage.get();\n",
    "\n",
    "  // If tools should be called\n",
    "  if ( message.hasToolExecutionRequests() ) return \"next\";\n",
    "  \n",
    "  // If no tools are called, we can finish (respond to the user)\n",
    "  return \"exit\";\n",
    "};\n",
    "\n",
    "// Call Model\n",
    "NodeAction<MessageState> callModel = state -> {\n",
    "  \n",
    "\n",
    "  var request = ChatRequest.builder()\n",
    "                .messages( state.messages() )\n",
    "                .build();\n",
    "\n",
    "  var response = model.chat(request );\n",
    "\n",
    "  return Map.of( \"messages\", response.aiMessage() );\n",
    "};\n",
    "\n",
    "var searchTool = new SearchTool();\n",
    "\n",
    "\n",
    "// Invoke Tool \n",
    "NodeAction<MessageState> invokeTool = state -> {\n",
    "\n",
    "  var lastMessage = (AiMessage)state.lastMessage()\n",
    "                          .orElseThrow( () -> ( new IllegalStateException( \"last message not found!\")) );\n",
    "\n",
    "  var executionRequest = lastMessage.toolExecutionRequests().get(0);\n",
    "\n",
    "  var executor = new DefaultToolExecutor( searchTool, executionRequest );\n",
    "\n",
    "  var result = executor.execute( executionRequest, null );\n",
    "\n",
    "  return Map.of( \"messages\", AiMessage.from( result ) );\n",
    "};\n",
    "\n",
    "// Define Graph\n",
    "\n",
    "var workflow = new StateGraph<MessageState> ( MessageState.SCHEMA, stateSerializer )\n",
    "  .addNode(\"agent\", node_async(callModel) )\n",
    "  .addNode(\"tools\", node_async(invokeTool) )\n",
    "  .addEdge(START, \"agent\")\n",
    "  .addConditionalEdges(\"agent\", edge_async(routeMessage), Map.of( \"next\", \"tools\", \"exit\", END ))\n",
    "  .addEdge(\"tools\", \"agent\");\n",
    "\n",
    "var graph = workflow.compile();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__START__\n",
      "agent\n",
      "{messages=[AiMessage { text = \"Hi I'm Bartolo, niced to meet you.\" toolExecutionRequests = [] }, AiMessage { text = \"Hello Bartolo, nice to meet you too! How can I assist you today?\" toolExecutionRequests = [] }]}\n",
      "__END__\n"
     ]
    }
   ],
   "source": [
    "\n",
    "Map<String,Object> inputs = Map.of( \"messages\", AiMessage.from(\"Hi I'm Bartolo, niced to meet you.\") );\n",
    "\n",
    "var result = graph.stream( inputs );\n",
    "\n",
    "for( var r : result ) {\n",
    "  System.out.println( r.node() );\n",
    "  if( r.node().equals(\"agent\")) {\n",
    "    System.out.println( r.state() );\n",
    "  }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__START__\n",
      "agent\n",
      "{messages=[AiMessage { text = \"Remember my name?\" toolExecutionRequests = [] }, AiMessage { text = \"I'm sorry, but I don't have the ability to remember personal information or previous interactions. How can I assist you today?\" toolExecutionRequests = [] }]}\n",
      "__END__\n"
     ]
    }
   ],
   "source": [
    "\n",
    "Map<String,Object> inputs = Map.of( \"messages\", AiMessage.from(\"Remember my name?\") );\n",
    "\n",
    "var result = graph.stream( inputs );\n",
    "\n",
    "for( var r : result ) {\n",
    "  System.out.println( r.node() );\n",
    "  if( r.node().equals(\"agent\")) {\n",
    "    System.out.println( r.state() );\n",
    "  }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add Memory\n",
    "\n",
    "Let's try it again with a checkpointer. We will use the\n",
    "[`MemorySaver`],\n",
    "which will \"save\" checkpoints in-memory.\n",
    "\n",
    "[`MemorySaver`]: https://bsorrentino.github.io/langgraph4j/apidocs/org/bsc/langgraph4j/checkpoint/MemorySaver.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.bsc.langgraph4j.checkpoint.MemorySaver; \n",
    "import org.bsc.langgraph4j.CompileConfig; \n",
    "\n",
    "// Here we only save in-memory\n",
    "var memory = new MemorySaver();\n",
    "\n",
    "var compileConfig = CompileConfig.builder()\n",
    "                    .checkpointSaver(memory)\n",
    "                    .build();\n",
    "\n",
    "var persistentGraph = workflow.compile( compileConfig );"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__START__\n",
      "agent\n",
      "AiMessage { text = \"Hello Bartolo, nice to meet you too! How can I assist you today?\" toolExecutionRequests = [] }\n",
      "__END__\n"
     ]
    }
   ],
   "source": [
    "import org.bsc.langgraph4j.RunnableConfig;\n",
    "\n",
    "var runnableConfig =  RunnableConfig.builder()\n",
    "                .threadId(\"conversation-num-1\" )\n",
    "                .build();\n",
    "\n",
    "Map<String,Object> inputs = Map.of( \"messages\", AiMessage.from(\"Hi I'm Bartolo, niced to meet you.\") );\n",
    "\n",
    "var result = persistentGraph.stream( inputs, runnableConfig );\n",
    "\n",
    "for( var r : result ) {\n",
    "  System.out.println( r.node() );\n",
    "  if( r.node().equals(\"agent\")) {\n",
    "    System.out.println( r.state().lastMessage().orElse(null) );\n",
    "  }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__START__\n",
      "agent\n",
      "AiMessage { text = \"Yes, your name is Bartolo. How can I assist you today?\" toolExecutionRequests = [] }\n",
      "__END__\n"
     ]
    }
   ],
   "source": [
    "\n",
    "Map<String,Object> inputs = Map.of( \"messages\", AiMessage.from(\"Remember my name?\") );\n",
    "\n",
    "var result = persistentGraph.stream( inputs, runnableConfig );\n",
    "\n",
    "for( var r : result ) {\n",
    "  System.out.println( r.node() );\n",
    "  if( r.node().equals(\"agent\")) {\n",
    "    System.out.println( r.state().lastMessage().orElse(null) );\n",
    "  }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## New Conversational Thread\n",
    "\n",
    "If we want to start a new conversation, we can pass in a different\n",
    "**`thread_id`**. Poof! All the memories are gone (just kidding, they'll always\n",
    "live in that other thread)!\n"
   ]
  },
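   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "Why a different thread id starts from a clean slate can be pictured with a small plain-Java sketch. It is only a conceptual model, not how `MemorySaver` is implemented: `SketchMemoryStore` and its methods are hypothetical names, and the sketch assumes that saved history is simply keyed by thread id.\n",
     "\n",
     "```java\n",
     "import java.util.ArrayList;\n",
     "import java.util.HashMap;\n",
     "import java.util.List;\n",
     "import java.util.Map;\n",
     "\n",
     "// Hypothetical stand-in for an in-memory store keyed by thread id.\n",
     "final class SketchMemoryStore {\n",
     "    private final Map<String, List<String>> historyByThread = new HashMap<>();\n",
     "\n",
     "    // Appends a message to the given thread, creating the thread on first use.\n",
     "    void append(String threadId, String message) {\n",
     "        historyByThread.computeIfAbsent(threadId, id -> new ArrayList<>()).add(message);\n",
     "    }\n",
     "\n",
     "    // Returns a copy of the thread's history (empty for an unknown thread).\n",
     "    List<String> history(String threadId) {\n",
     "        return List.copyOf(historyByThread.getOrDefault(threadId, List.of()));\n",
     "    }\n",
     "}\n",
     "```\n",
     "\n",
     "In this model, messages appended under one thread id are never visible under another, which matches the behaviour observed in the cells that follow."
    ]
   },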
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "runnableConfig =  RunnableConfig.builder()\n",
    "                .threadId(\"conversation-2\" )\n",
    "                .build();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__START__\n",
      "agent\n",
      "AiMessage { text = \"No, I don't know your name. I don't have access to personal data about individuals unless it has been shared with me in the course of our conversation. If you have any questions or need assistance, feel free to ask!\" toolExecutionRequests = [] }\n",
      "__END__\n"
     ]
    }
   ],
   "source": [
    "inputs = Map.of( \"messages\", AiMessage.from(\"you know my name?\") );\n",
    "\n",
    "var result = persistentGraph.stream( inputs, runnableConfig );\n",
    "\n",
    "for( var r : result ) {\n",
    "  System.out.println( r.node() );\n",
    "  if( r.node().equals(\"agent\")) {\n",
    "    System.out.println( r.state().lastMessage().orElse(null) );\n",
    "  }\n",
    "}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java (rjk 2.2.0)",
   "language": "java",
   "name": "rapaio-jupyter-kernel"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "java",
   "nbconvert_exporter": "script",
   "pygments_lexer": "java",
   "version": "22.0.2+9-70"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
